permit NNlibCUDA to use Float16 #363
_batched_mul!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat} =
    _batched_try_gemm!(DT, C, A, B, α, β)

function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat}
My concern with this change (removing the {T<:BlasFloat} restriction, which isn't highlighted well) is that it may send weird numbers (like Dual, or BigFloat) down the path towards batched_gemm!, which won't accept them.
Perhaps, to safely widen here, the method _batched_gemm!(::Type{<:Array}, ...) below needs to be restricted to Array{<:BlasFloat}? With a new method offering another path to batched_mul_generic! at that stage?
The dispatch in this file is pretty convoluted! Maybe there's another tidier solution.
Float16 would be good to have, though. Thanks for digging.
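To make the suggestion above concrete, here is a minimal sketch of what "safely widening" could look like. It is a toy model, not NNlib's code: every `toy_` name is a hypothetical stand-in, the argument lists are shortened, and the functions return symbols instead of doing any arithmetic. The only real name used is `LinearAlgebra.BlasFloat`.

```julia
using LinearAlgebra: BlasFloat   # Union of Float32, Float64, ComplexF32, ComplexF64

# Hypothetical stand-ins: the real functions would do the work, these only
# report which path dispatch selected.
toy_blas!(C, A, B)    = :gemm      # models the BLAS batched_gemm! call
toy_generic!(C, A, B) = :generic   # models batched_mul_generic!

# Widened entry point: no restriction on the element type any more.
toy_try_gemm!(::Type{DT}, C, A, B) where {DT<:DenseArray} = toy_gemm!(DT, C, A, B)

# Only BLAS-compatible element types may reach the BLAS path ...
toy_gemm!(::Type{<:Array{<:BlasFloat}}, C, A, B) = toy_blas!(C, A, B)
# ... and every other CPU element type falls back to the generic loop.
toy_gemm!(::Type{<:Array}, C, A, B) = toy_generic!(C, A, B)

A32  = rand(Float32, 2, 2, 1)
Abig = big.(rand(2, 2, 1))        # Array{BigFloat,3}
A16  = rand(Float16, 2, 2, 1)

toy_try_gemm!(typeof(A32), A32, A32, A32)      # selects the BLAS path
toy_try_gemm!(typeof(Abig), Abig, Abig, Abig)  # selects the generic path
toy_try_gemm!(typeof(A16), A16, A16, A16)      # selects the generic path
```

With this shape, the widened method can be reached by any element type, but the restriction to BLAS-compatible types is re-imposed one step later for plain Arrays.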
the only place this method (i.e. _batched_try_gemm!) is currently called is from the method immediately above (i.e. _batched_mul!() where {T<:BlasFloat}). widening _batched_try_gemm! to types other than BlasFloat permits the proposed new _batched_mul!() where {T<:Float16} in FluxML/NNlibCUDA.jl#32 to call it too. i don't think there's any danger of weird number types getting where they shouldn't.
Oh, now I see better what you're proposing. There are two jumps to the CUDA package, in order to allow Float16 only for CuArrays, not for Arrays. Which is the desired behaviour. The first jump comes back to this package's chain of functions.
It does seem slightly weird to jump twice. Let me think a bit more, I'd be happier if there was exactly one point in the chain where dispatch cared about CuArrays.
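The "two jumps" can be illustrated with a small runnable model. This is only a sketch under assumptions: `ToyGPUArray` is a hypothetical stand-in for CuArray so it runs without a GPU, the `_toy_` functions are hypothetical and return symbols rather than calling any kernel, and the argument lists are shortened compared with the real methods.

```julia
using LinearAlgebra: BlasFloat

# Hypothetical stand-in for CuArray, so the sketch runs without a GPU.
struct ToyGPUArray{T,N} <: DenseArray{T,N}
    data::Array{T,N}
end
Base.size(x::ToyGPUArray) = size(x.data)
Base.getindex(x::ToyGPUArray{T,N}, I::Vararg{Int,N}) where {T,N} = x.data[I...]

# ---- "base package" side (models this package's chain) ----
toy_mul!(C, A, B) = _toy_mul!(typeof(C), C, A, B)

# BLAS element types try the gemm path; anything else is generic.
_toy_mul!(::Type{DT}, C, A, B) where {DT<:DenseArray{T}} where {T<:BlasFloat} =
    _toy_try_gemm!(DT, C, A, B)
_toy_mul!(::Type, C, A, B) = :generic

# Widened: no element-type restriction, so other callers can enter here too.
_toy_try_gemm!(::Type{DT}, C, A, B) where {DT<:DenseArray} = _toy_gemm!(DT, C, A, B)
_toy_gemm!(::Type{<:Array}, C, A, B) = :cpu_gemm

# ---- "GPU package" side (models the CUDA extension) ----
# Jump 1: admit Float16 for GPU arrays only, then re-enter the base chain.
_toy_mul!(::Type{DT}, C, A, B) where {DT<:ToyGPUArray{Float16}} =
    _toy_try_gemm!(DT, C, A, B)
# Jump 2: the final kernel call is again chosen by array type.
_toy_gemm!(::Type{<:ToyGPUArray}, C, A, B) = :gpu_gemm

cpu32 = rand(Float32, 2, 2, 1)
cpu16 = rand(Float16, 2, 2, 1)
gpu32 = ToyGPUArray(cpu32)
gpu16 = ToyGPUArray(cpu16)

toy_mul!(cpu32, cpu32, cpu32)   # CPU gemm path
toy_mul!(cpu16, cpu16, cpu16)   # generic path: Float16 stays off BLAS on the CPU
toy_mul!(gpu32, gpu32, gpu32)   # GPU gemm path, via jump 2 only
toy_mul!(gpu16, gpu16, gpu16)   # GPU gemm path, via jump 1 and then jump 2
```

In this model, Float16 reaches a gemm kernel only when the array type is the GPU one, which matches the behaviour described above as desired.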
ping
Sorry I dropped the ball here. I think we should do this, or at least I certainly didn't get around to thinking up a better way.
Could you perhaps add some comments explaining a bit what's going on? Having dispatch at two points, instead of just reading down the page and at some point jumping to CUDA, is one step trickier to read. Maybe the where {DT<:DenseArray{T}} where {T<:BlasFloat} = ... method can explain that there's another path through here for CuArray{Float16}?
in conjunction with FluxML/NNlibCUDA.jl#32, add support for half-precision
gemm, for which a special kernel is provided by Nvidia. see JuliaGPU/CUDA.jl#1080